{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Yogesh914/VideoGPT-plus/blob/main/VideoGPT%2B_Demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# VideoGPT+ Demo 🎥"
      ],
      "metadata": {
        "id": "CP7--7iTuciZ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setup"
      ],
      "metadata": {
        "id": "Wcomhj6fukGQ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D_h0ivDItx21"
      },
      "outputs": [],
      "source": [
        "# Clone the VideoGPT-plus source repository from GitHub into the current working directory.\n",
        "!git clone https://github.com/mbzuai-oryx/VideoGPT-plus"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Enter the cloned repository. The absolute path matches the /content/... model_path\n",
        "# used in the Inference cell and keeps this cell re-runnable: a relative `%cd`\n",
        "# fails once the working directory is already inside the repository.\n",
        "%cd /content/VideoGPT-plus"
      ],
      "metadata": {
        "id": "Ds8NftGDyBBW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
        "%pip install huggingface_hub\n",
        "\n",
        "from huggingface_hub import notebook_login\n",
        "\n",
        "# Interactive Hugging Face Hub login.\n",
        "notebook_login()"
      ],
      "metadata": {
        "id": "J02UHN8WvU4g"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Clone the OpenGVLab/InternVideo2-Stage2_1B-224p-f4 repository from the Hugging Face Hub.\n",
        "!git clone https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4"
      ],
      "metadata": {
        "id": "I7qQvagisjQe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Move the cloned checkpoint under an OpenGVLab/ directory.\n",
        "# `-p` keeps the cell re-runnable when the directory already exists.\n",
        "!mkdir -p OpenGVLab\n",
        "!mv InternVideo2-Stage2_1B-224p-f4 OpenGVLab/"
      ],
      "metadata": {
        "id": "TL5-SUvt1Ih7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Clone the MBZUAI/VideoGPT-plus_Phi3-mini-4k repository from the Hugging Face Hub.\n",
        "# The Inference cell reads its `vcgbench` subdirectory via `model_path`.\n",
        "!git clone https://huggingface.co/MBZUAI/VideoGPT-plus_Phi3-mini-4k"
      ],
      "metadata": {
        "id": "n1IEfAC8uMN3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Python dependencies for inference.\n",
        "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
        "%pip install shortuuid timm einops flash-attn decord mmengine ninja peft"
      ],
      "metadata": {
        "id": "S_UQhfdEuw4E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Inference"
      ],
      "metadata": {
        "id": "kXx6QFlUumnQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from tqdm import tqdm\n",
        "from videogpt_plus.conversation import conv_templates\n",
        "from videogpt_plus.model.builder import load_pretrained_model\n",
        "from videogpt_plus.mm_utils import tokenizer_image_token, get_model_name_from_path\n",
        "# The star import is the only source in this cell for DEFAULT_IMAGE_TOKEN,\n",
        "# DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN,\n",
        "# IMAGE_TOKEN_INDEX, NUM_FRAMES and NUM_CONTEXT_IMAGES, which are used below.\n",
        "from eval.vcgbench.inference.ddp import *\n",
        "from eval.video_encoding import _get_rawvideo_dec\n",
        "import traceback\n",
        "import pandas as pd\n",
        "import torch\n",
        "import os\n",
        "\n",
        "# Replace reset_parameters on Linear and LayerNorm with no-ops\n",
        "setattr(torch.nn.Linear, \"reset_parameters\", lambda self: None)\n",
        "setattr(torch.nn.LayerNorm, \"reset_parameters\", lambda self: None)\n",
        "\n",
        "# Configuration: edit `video_path` and `qs` before running\n",
        "model_path = \"/content/VideoGPT-plus/VideoGPT-plus_Phi3-mini-4k/vcgbench\"\n",
        "model_base = \"microsoft/Phi-3-mini-4k-instruct\"\n",
        "video_path = \"path/to/video\"\n",
        "qs = \"PROMPT\"\n",
        "\n",
        "conv_mode = \"phi3_instruct\"\n",
        "stop_str = \"<|end|>\"\n",
        "temperature = 0.0  # 0 disables sampling (see do_sample below)\n",
        "\n",
        "# Fail fast, before the model is loaded. Without this check a missing video\n",
        "# only surfaces later as a NameError on slice_len / video_frames.\n",
        "if not os.path.exists(video_path):\n",
        "    raise FileNotFoundError(\n",
        "        f\"Video file not found: {video_path!r}. Set `video_path` to an existing video.\"\n",
        "    )\n",
        "\n",
        "# Load pretrained model and tokenizer\n",
        "model_path = os.path.expanduser(model_path)\n",
        "model_name = get_model_name_from_path(model_path)\n",
        "# NOTE: the image_processor returned here is replaced further down by the\n",
        "# image vision tower's processor.\n",
        "tokenizer, model, image_processor, context_len = load_pretrained_model(\n",
        "    model_path, model_base, model_name\n",
        ")\n",
        "\n",
        "# Configure model for multimodal use\n",
        "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
        "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
        "if mm_use_im_patch_token:\n",
        "    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
        "if mm_use_im_start_end:\n",
        "    tokenizer.add_tokens(\n",
        "        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True\n",
        "    )\n",
        "model.resize_token_embeddings(len(tokenizer))\n",
        "\n",
        "# Load vision towers\n",
        "vision_tower = model.get_vision_tower()\n",
        "vision_tower.load_model(model.config.mm_vision_tower)\n",
        "video_processor = vision_tower.image_processor\n",
        "\n",
        "image_vision_tower = model.get_image_vision_tower()\n",
        "image_vision_tower.load_model()\n",
        "image_processor = image_vision_tower.image_processor\n",
        "\n",
        "# Move model to GPU\n",
        "model = model.to(\"cuda\")\n",
        "\n",
        "# Decode the video into frames for the two vision towers\n",
        "video_frames, context_frames, slice_len = _get_rawvideo_dec(\n",
        "    video_path,\n",
        "    image_processor,\n",
        "    video_processor,\n",
        "    max_frames=NUM_FRAMES,\n",
        "    image_resolution=224,\n",
        "    num_video_frames=NUM_FRAMES,\n",
        "    num_context_images=NUM_CONTEXT_IMAGES,\n",
        ")\n",
        "\n",
        "# Prefix the question with slice_len image tokens. The flag computed above with\n",
        "# getattr is reused so a config without this attribute does not raise here.\n",
        "# `qs` itself is left untouched.\n",
        "if mm_use_im_start_end:\n",
        "    query = (\n",
        "        DEFAULT_IM_START_TOKEN\n",
        "        + DEFAULT_IMAGE_TOKEN * slice_len\n",
        "        + DEFAULT_IM_END_TOKEN\n",
        "        + \"\\n\"\n",
        "        + qs\n",
        "    )\n",
        "else:\n",
        "    query = DEFAULT_IMAGE_TOKEN * slice_len + \"\\n\" + qs\n",
        "\n",
        "# Create conversation prompt\n",
        "conv = conv_templates[conv_mode].copy()\n",
        "conv.append_message(conv.roles[0], query)\n",
        "conv.append_message(conv.roles[1], None)\n",
        "prompt = conv.get_prompt()\n",
        "\n",
        "input_ids = (\n",
        "    tokenizer_image_token(\n",
        "        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=\"pt\"\n",
        "    )\n",
        "    .unsqueeze(0)\n",
        "    .cuda()\n",
        ")\n",
        "\n",
        "# Generate output\n",
        "with torch.inference_mode():\n",
        "    output_ids = model.generate(\n",
        "        input_ids,\n",
        "        images=torch.stack(video_frames).half().cuda(),\n",
        "        context_images=torch.stack(context_frames).half().cuda(),\n",
        "        do_sample=temperature > 0,\n",
        "        temperature=temperature,\n",
        "        top_p=None,\n",
        "        num_beams=1,\n",
        "        max_new_tokens=1024,\n",
        "        use_cache=True,\n",
        "    )\n",
        "\n",
        "# Validate output: the leading tokens of the output should match the input\n",
        "input_token_len = input_ids.shape[1]\n",
        "n_diff_input_output = (\n",
        "    (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
        ")\n",
        "if n_diff_input_output > 0:\n",
        "    print(\n",
        "        f\"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids\"\n",
        "    )\n",
        "\n",
        "# Decode only the newly generated tokens, then drop every stop string and\n",
        "# the surrounding whitespace\n",
        "outputs = tokenizer.batch_decode(\n",
        "    output_ids[:, input_token_len:], skip_special_tokens=True\n",
        ")[0]\n",
        "outputs = outputs.replace(stop_str, \"\").strip()\n",
        "\n",
        "print(outputs)"
      ],
      "metadata": {
        "id": "OUIabt9fuB3_"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}